Source code for hysop.tools.callback

# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
from hysop.tools.units import unit2str, time2str, bytes2str, bdw2str
from hysop.backend.device.codegen.base.statistics import WorkStatistics


def _to_list(e):
    """Coerce *e* into a list.

    A list is returned unchanged (same object), a tuple or a set is copied
    into a new list, a numpy array is converted with ``tolist()`` and any
    other object is wrapped into a single-element list.
    """
    if isinstance(e, list):
        return e
    if isinstance(e, (tuple, set)):
        return list(e)
    if isinstance(e, np.ndarray):
        return e.tolist()
    return [e]


class TimerInterface:
    """Accumulate the timings of repeated runs and expose basic statistics.

    Attributes
    ----------
    state:
        ``None`` until a first run is started; a negative value while a run
        is in progress and a non-negative value otherwise (see ``status``).
    min, max:
        Smallest and largest registered timing, ``None`` before any run.
    nruns:
        Number of registered timings.
    data:
        Every registered timing, in registration order.
    """

    def __init__(self, **kargs):
        # Cooperative constructor: remaining keyword arguments are forwarded
        # to the next class in the MRO.
        super().__init__(**kargs)
        self.state = None
        self.min = None
        self.max = None
        self.nruns = 0
        self.data = []

    def mean(self):
        """Return the average registered timing, or None before any run."""
        if self.nruns:
            return sum(self.data) / float(self.nruns)
        return None

    def total(self):
        """Return the sum of all registered timings (0 before any run)."""
        return sum(self.data) if self.nruns else 0

    def status(self):
        """Return "W" (waiting first run), "R" (running) or "S" (sleeping)."""
        if self.state is None:
            return "W"
        return "R" if self.state < 0 else "S"

    def register_timing(self, timing):
        """Record one timing and update min, max and the run counter."""
        if self.min is None or timing < self.min:
            self.min = timing
        if self.max is None or timing > self.max:
            self.max = timing
        self.nruns += 1
        self.data.append(timing)

    def __str__(self):
        return (
            f"({self.status()}) nruns={self.nruns:4d}, "
            f"min={time2str(self.min)}, max={time2str(self.max)}, "
            f"mean={time2str(self.mean())}, total={time2str(self.total())}"
        )

    @staticmethod
    def _as_group(groupname, tasks, tic_callbacks=[], tac_callbacks=[]):
        """Build the TimingGroup named *groupname* wrapping *tasks*."""
        group_options = dict(
            name=groupname,
            tasks=tasks,
            tic_callbacks=tic_callbacks,
            tac_callbacks=tac_callbacks,
        )
        return TimingGroup(**group_options)
class MemInterface(TimerInterface):
    """Timer that also tracks the bandwidth reached by each run.

    The bandwidth of a run is ``membytes / timing``; ``membytes`` is fixed at
    construction and is the same for every run.
    """

    def __init__(self, membytes, **kargs):
        super().__init__(**kargs)
        self.membytes = membytes
        self.min_bandwidth = None
        self.max_bandwidth = None
        self.bandwidth = []

    def register_timing(self, timing):
        """Record one timing together with the bandwidth it corresponds to."""
        super().register_timing(timing)
        current = self.membytes / float(timing)
        lowest, highest = self.min_bandwidth, self.max_bandwidth
        if lowest is None or current < lowest:
            self.min_bandwidth = current
        if highest is None or current > highest:
            self.max_bandwidth = current
        self.bandwidth.append(current)

    def mean_bandwidth(self):
        """Return the average bandwidth over all runs, or None before any."""
        if self.nruns:
            return sum(self.bandwidth) / float(self.nruns)
        return None

    def total_mem_moved(self):
        """Return ``nruns * membytes``."""
        return self.membytes * self.nruns

    def __str__(self):
        template = (
            "\n{:15s} min_bdw={}, max_bdw={}, mean_bdw={}, total_mem_moved={}"
        )
        # The empty first field only produces the 15-character padding.
        bandwidth_line = template.format(
            "",
            bdw2str(self.min_bandwidth),
            bdw2str(self.max_bandwidth),
            bdw2str(self.mean_bandwidth()),
            bytes2str(self.total_mem_moved()),
        )
        return TimerInterface.__str__(self) + bandwidth_line
class MemcpyInterface(MemInterface):
    """MemInterface variant whose report displays doubled bandwidths.

    NOTE(review): the factor 2 is presumably there because a copy both reads
    and writes every byte -- confirm with the callers that set ``membytes``.
    """

    def __init__(self, membytes, **kargs):
        super().__init__(membytes=membytes, **kargs)

    def __str__(self):
        def doubled(bandwidth):
            # Bandwidths stay None until a first timing is registered; forward
            # None unchanged (as MemInterface.__str__ does) instead of failing
            # on ``2 * None``.
            return None if bandwidth is None else 2 * bandwidth

        s = "\n{:15s} min_bdw={}, max_bdw={}, mean_bdw={}, total_mem_moved={}".format(
            "",
            bdw2str(doubled(self.min_bandwidth)),
            bdw2str(doubled(self.max_bandwidth)),
            bdw2str(doubled(self.mean_bandwidth())),
            bytes2str(self.total_mem_moved()),
        )
        return TimerInterface.__str__(self) + s
class ComputeInterface(MemInterface):
    """Timer for a compute task whose cost is described by a WorkStatistics.

    The memory volume handed to MemInterface is
    ``total_work * per_work_statistic.global_mem_transactions()``.
    """

    def __init__(self, total_work, per_work_statistic, ftype="float", **kargs):
        """
        Parameters
        ----------
        total_work:
            Multiplier applied to ``per_work_statistic``; must be >= 1.
        per_work_statistic: WorkStatistics
            Statistics of a single unit of work.
        ftype: str
            One of "float", "double" or "half"; only checked in ``__str__``.
        kargs:
            Forwarded to the next constructor in the MRO.

        Raises
        ------
        ValueError
            If per_work_statistic is not a WorkStatistics or total_work < 1.
        """
        if not isinstance(per_work_statistic, WorkStatistics):
            raise ValueError("per_work_statistic is not a WorkStatistics")
        if total_work < 1:
            raise ValueError("total_work < 1.")
        membytes = total_work * per_work_statistic.global_mem_transactions()
        super().__init__(membytes=membytes, **kargs)
        self.ftype = ftype
        self.total_work = total_work
        self.per_work_statistic = per_work_statistic
        # Accumulated over all registered runs, see register_timing().
        self.total_work_statistic = WorkStatistics()

    def register_timing(self, timing):
        """Record one timing and add one run worth of work statistics."""
        super().register_timing(timing)
        self.total_work_statistic += self.total_work * self.per_work_statistic

    def stats_per_second(self):
        """Return the accumulated statistics timed over total(), or None
        when no run has been registered yet."""
        if self.nruns == 0:
            return None
        else:
            return self.total_work_statistic.compute_timed_statistics(self.total())

    def __str__(self):
        """Return the timer summary followed by throughput and operation
        rates (the latter only once at least one run was registered).

        Raises ValueError when ftype is not "float", "double" or "half".
        """
        s = ""
        timed_stats = self.stats_per_second()
        if timed_stats is not None:
            # Label and scale factor applied to the "FLOPS" category,
            # selected from ftype.
            if self.ftype == "float":
                float_op_category = "FLOPS (FP32)"
                float_op_factor = 1.0
            elif self.ftype == "double":
                float_op_category = "FLOPS (FP64)"
                float_op_factor = 0.5
            elif self.ftype == "half":
                float_op_category = "FLOPS (FP16)"
                float_op_factor = 2.0
            else:
                raise ValueError()
            flops = timed_stats.ops_per_category()["FLOPS"]
            flops *= float_op_factor
            # Scaled FLOPS divided by global memory transactions.
            opi = flops / timed_stats.global_mem_transactions()
            if timed_stats.global_mem_throughput() > 0:
                s += f" throughput={bdw2str(timed_stats.global_mem_throughput())}"
                if (
                    timed_stats.global_mem_throughput()
                    < timed_stats.total_mem_throughput()
                ):
                    # The total is only shown when it exceeds the global
                    # memory throughput.
                    s += f" (tot={bdw2str(timed_stats.total_mem_throughput())})"
                s += " OPI={}".format(unit2str(opi, "FLOP/B", decimal=True, rounded=2))
            for op_category, ops_per_second in timed_stats.ops_per_second().items():
                if op_category != "FLOPS":
                    s += f" {unit2str(ops_per_second,op_category,decimal=True,rounded=2)}"
                else:
                    s += f" {unit2str(ops_per_second*float_op_factor,float_op_category,decimal=True,rounded=2)}"
        return TimerInterface.__str__(self) + s
class CallbackTask:
    """Named task that runs user callbacks on ``tic()`` and ``tac()``.

    Subclasses provide ``_on_tic``, ``_on_tac`` and ``__str__``; the versions
    defined here raise NotImplementedError.
    """

    def __init__(self, name, tic_callbacks=[], tac_callbacks=[], **kargs):
        # Cooperative constructor: remaining keyword arguments are forwarded
        # to the next class in the MRO.
        super().__init__(**kargs)
        self.name = name
        self.tic_callbacks = []
        self.tac_callbacks = []
        self.register_callbacks(tic_callbacks, tac_callbacks)

    def tic(self, **kargs):
        """Run ``_on_tic`` and then every tic callback as ``cb(self, **kargs)``."""
        self._on_tic(**kargs)
        for callback in self.tic_callbacks:
            callback(self, **kargs)

    def tac(self, **kargs):
        """Run ``_on_tac`` and then every tac callback as ``cb(self, **kargs)``."""
        self._on_tac(**kargs)
        for callback in self.tac_callbacks:
            callback(self, **kargs)

    def register_callbacks(self, tic_callbacks=[], tac_callbacks=[]):
        """Append the given callbacks, skipping those already registered.

        Each argument may be a single callback or a collection of callbacks.
        """
        pending = (
            (self.tic_callbacks, _to_list(tic_callbacks)),
            (self.tac_callbacks, _to_list(tac_callbacks)),
        )
        for registered, incoming in pending:
            for callback in incoming:
                if callback not in registered:
                    registered.append(callback)

    def _not_implemented(self, method):
        """Raise the NotImplementedError shared by the abstract hooks."""
        msg = f"{method} not implemented in class {self.__class__.__name__}."
        raise NotImplementedError(msg)

    def _on_tic(self, **kargs):
        self._not_implemented("_on_tic")

    def _on_tac(self, **kargs):
        self._not_implemented("_on_tac")

    def __str__(self):
        self._not_implemented("__str__")

    def report(self, offset):
        """Return the task name, indented by *offset* and padded to 15 chars."""
        return f"{self.offset_str(offset)}{self.name:15s}"

    @staticmethod
    def offset_str(count):
        """Return the indentation string for nesting level *count*."""
        return count * " "

    @staticmethod
    def format(s, count):
        """Indent every continuation line of *s* for nesting level *count*."""
        continuation = "\n " + CallbackTask.offset_str(count)
        return s.replace("\n", continuation)
class CallbackGroup(CallbackTask):
    """Task made of several tasks, driven by the tic/tac of its members.

    The group registers a callback on every member task. Its own tic fires
    when all members have ticked and its own tac fires when all members have
    tacked, after which both sets of flags are reset for the next round.
    """

    def __init__(self, name, tasks, **kargs):
        """
        Parameters
        ----------
        name:
            Name of the group.
        tasks:
            Non-empty sequence of tasks exposing ``name`` and
            ``register_callbacks``.
        kargs:
            Forwarded to CallbackTask (tic_callbacks, tac_callbacks, ...).

        Raises
        ------
        ValueError
            If the task list is rejected by ``_check``.
        """
        super().__init__(name, **kargs)
        self.tasks = tasks
        # One flag per member task, indexed through self._taskid.
        self.ticked = np.zeros(shape=(len(tasks),), dtype=bool)
        self.tacked = self.ticked.copy()
        self._taskid = {task.name: i for i, task in enumerate(tasks)}
        self._check()

        def _on_task_tic(task, **args):
            self.ticked[self.taskid(task)] = True
            if self.ticked.all():
                # Bypass CallbackGroup.tic(), which always raises.
                super(CallbackGroup, self).tic(**args)

        def _on_task_tac(task, **args):
            self.tacked[self.taskid(task)] = True
            if self.tacked.all():
                super(CallbackGroup, self).tac(**args)
                self.ticked[:] = False
                self.tacked[:] = False

        for task in tasks:
            task.register_callbacks(
                tic_callbacks=_on_task_tic, tac_callbacks=_on_task_tac
            )

    def _check(self):
        if len(self.tasks) == 0:
            raise ValueError("Empty task list!")

    def taskid(self, task):
        """Return the index of *task* (looked up by name) in the flag arrays."""
        return self._taskid[task.name]

    def tic(self, **kargs):
        """Always raise: a group is only ticked through its member tasks."""
        raise RuntimeError("CallbackGroup.tic() should not be called explicitely.")

    def tac(self, **kargs):
        """Always raise: a group is only tacked through its member tasks."""
        raise RuntimeError("CallbackGroup.tac() should never be called explicitely.")

    def report(self, offset=0):
        """Return the group name, one report line per member task sorted by
        decreasing total(), and a final "total:" line for the group."""
        s = ""
        s += f"{self.offset_str(offset)}{self.name}"
        tasks = tuple(sorted(self.tasks, key=lambda x: x.total(), reverse=True))
        for task in tasks:
            s += "\n" + task.report(offset + 1)
        s += "\n{}{:15s} {}".format(
            self.offset_str(offset + 1), "total:", self.__str__()
        )
        return s

    @classmethod
    def _as_group(cls, groupname, tasks, tic_callbacks=[], tac_callbacks=[], **kargs):
        """Build a group of the same class as the receiver wrapping *tasks*."""
        # Instantiate through ``cls``: ``self`` does not exist in a
        # classmethod, so referring to it raised NameError.
        return cls(
            name=groupname,
            tasks=tasks,
            tic_callbacks=tic_callbacks,
            tac_callbacks=tac_callbacks,
            **kargs,
        )
class TimingGroup(CallbackGroup, TimerInterface):
    """Group of timed tasks that is itself timed.

    Each time the group tacs, the sum of the last timing of every member
    task is registered as one timing of the group.
    """

    def __init__(self, **kargs):
        # CallbackGroup.__init__ already calls self._check(), which resolves
        # to the override below, so no second validation pass is needed.
        super().__init__(**kargs)

    def _check(self):
        """Ensure every task is both a TimerInterface and a CallbackTask.

        NOTE: this override does not call CallbackGroup._check, so the
        empty-task-list check performed there is not applied here.

        Raises
        ------
        ValueError
            On the first task that fails either isinstance check.
        """
        for task in self.tasks:
            if not isinstance(task, TimerInterface):
                expected = "TimerInterface"
            elif not isinstance(task, CallbackTask):
                expected = "CallbackTask"
            else:
                continue
            msg = "{} is not an instance of {}, got {} instead.".format(
                task.name, expected, task.__class__.__name__
            )
            raise ValueError(msg)

    def _on_tic(self, **kargs):
        pass

    def _on_tac(self, **kargs):
        # All member tasks have tacked at this point, so data[-1] is the
        # timing each of them registered for the current round.
        total = sum(task.data[-1] for task in self.tasks)
        TimerInterface.register_timing(self, total)

    def __str__(self):
        return TimerInterface.__str__(self)
class TimingTask(CallbackTask, TimerInterface):
    """Callback task that carries TimerInterface statistics."""

    def __init__(self, **kargs):
        super().__init__(**kargs)

    def report(self, offset=0):
        """Return the indented task name followed by its timer summary."""
        header = CallbackTask.report(self, offset)
        summary = TimerInterface.__str__(self)
        return f"{header} {summary}"

    def __str__(self):
        return self.report()
class MemcpyTask(CallbackTask, MemcpyInterface):
    """Callback task timed with ``MPI.Wtime()`` and reported as a memcpy."""

    def __init__(self, MPI, **kargs):
        super().__init__(**kargs)
        self.MPI = MPI

    def report(self, offset=0):
        """Return the indented task name followed by its memcpy summary."""
        header = CallbackTask.report(self, offset)
        summary = self.format(MemcpyInterface.__str__(self), offset)
        return f"{header} {summary}"

    def __str__(self):
        return self.report()

    def _on_tic(self, **kargs):
        # Store minus the start time: negative while the task is running.
        start = self.MPI.Wtime()
        self.state = -start

    def _on_tac(self, **kargs):
        # Adding the end time turns the stored value into the elapsed time.
        elapsed = self.state + self.MPI.Wtime()
        self.state = elapsed
        self.register_timing(elapsed)
class ComputeTask(CallbackTask, ComputeInterface):
    """Callback task timed with ``MPI.Wtime()`` and reported as a compute task."""

    def __init__(self, MPI, **kargs):
        super().__init__(**kargs)
        self.MPI = MPI

    def report(self, offset=0):
        """Return the indented task name followed by its compute summary."""
        header = CallbackTask.report(self, offset)
        summary = self.format(ComputeInterface.__str__(self), offset)
        return f"{header} {summary}"

    def __str__(self):
        return self.report()

    def _on_tic(self, **kargs):
        # Store minus the start time: negative while the task is running.
        start = self.MPI.Wtime()
        self.state = -start

    def _on_tac(self, **kargs):
        # Adding the end time turns the stored value into the elapsed time.
        elapsed = self.state + self.MPI.Wtime()
        self.state = elapsed
        self.register_timing(elapsed)
class MPITimingTask(TimingTask):
    """TimingTask whose runs are measured with ``MPI.Wtime()``."""

    def __init__(self, MPI, **kargs):
        super().__init__(**kargs)
        self.MPI = MPI

    def _on_tic(self, **kargs):
        # Store minus the start time: negative while the task is running.
        start = self.MPI.Wtime()
        self.state = -start

    def _on_tac(self, **kargs):
        # Adding the end time turns the stored value into the elapsed time.
        elapsed = self.state + self.MPI.Wtime()
        self.state = elapsed
        self.register_timing(elapsed)
class CallbackProfiler:
    """Registry of named tasks and task groups with tic/tac dispatch.

    Tasks and groups live in two dictionaries keyed by name; when a name is
    present in both, the group takes precedence for lookups.
    """

    def __init__(self, MPI):
        self._MPI = MPI
        self.tasks = {}
        self.groups = {}
        # Names of the targets that belong to at least one group.
        self._tasks_in_group = set()

    def tic(self, target, **kargs):
        """Call ``tic(**kargs)`` on the registered *target*."""
        self._check_registered(target)[target].tic(**kargs)

    def tac(self, target, **kargs):
        """Call ``tac(**kargs)`` on the registered *target*."""
        self._check_registered(target)[target].tac(**kargs)

    def register_tasks(self, tasks, tic_callbacks=[], tac_callbacks=[], **kargs):
        """Register one or several tasks and attach callbacks to them.

        Each entry of *tasks* is either a CallbackTask instance, whose name
        must not be registered yet, or a task name. A name that is already
        registered reuses the existing task (*kargs* is then ignored).
        Otherwise a ComputeTask, MemcpyTask or MPITimingTask is created,
        depending on whether *kargs* holds "per_work_statistic", "membytes"
        or neither.

        Raises
        ------
        ValueError
            If a CallbackTask instance reuses a registered name.
        """
        for task in _to_list(tasks):
            if isinstance(task, CallbackTask):
                taskname = task.name
                self._check_not_registered(taskname)
            else:
                taskname = task
                if taskname in self.tasks:
                    task = self.tasks[taskname]
                elif "per_work_statistic" in kargs:
                    task = ComputeTask(MPI=self._MPI, name=taskname, **kargs)
                elif "membytes" in kargs:
                    task = MemcpyTask(MPI=self._MPI, name=taskname, **kargs)
                else:
                    task = MPITimingTask(MPI=self._MPI, name=taskname, **kargs)
            task.register_callbacks(tic_callbacks, tac_callbacks)
            self.tasks[taskname] = task

    def register_group(self, groupname, tasknames, tic_callbacks=[], tac_callbacks=[]):
        """Register a group of targets, or add callbacks to an existing one.

        Names in *tasknames* that are not registered are skipped. When the
        group already exists, *tasknames* must match its current members.

        Raises
        ------
        RuntimeError
            If the group exists with a different set of members.
        IndexError
            If a new group ends up with no registered member.
        """
        if groupname in self.groups:
            existing = self.groups[groupname]
            if set(tasknames) != {task.name for task in existing.tasks}:
                msg = f"Group {groupname} was already registered!"
                raise RuntimeError(msg)
            # Same members: just update callbacks.
            existing.register_callbacks(tic_callbacks, tac_callbacks)
            return

        registered = self.registered_targets()
        tasks = []
        for taskname in tasknames:
            if taskname in registered:
                tasks.append(self._check_registered(taskname)[taskname])
                self._tasks_in_group.add(taskname)
        # The first member decides which group class gets built.
        group = tasks[0]._as_group(groupname, tasks, tic_callbacks, tac_callbacks)
        group.register_callbacks(tic_callbacks, tac_callbacks)
        self.groups[groupname] = group

    def registered_targets(self):
        """Return all group names followed by all task names."""
        return tuple(self.groups.keys()) + tuple(self.tasks.keys())

    def register_callbacks(self, target, tic_callbacks=[], tac_callbacks=[]):
        """Attach callbacks to the registered task or group *target*."""
        dic = self._check_registered(target)
        # Use the parameter names of CallbackTask.register_callbacks; the
        # former ``tic=`` / ``tac=`` keywords raised TypeError.
        dic[target].register_callbacks(
            tic_callbacks=tic_callbacks, tac_callbacks=tac_callbacks
        )

    def has_tasks(self):
        """Return True when at least one task is registered."""
        return bool(self.tasks)

    def has_groups(self):
        """Return True when at least one group is registered."""
        return bool(self.groups)

    def _check_registered(self, target):
        """Return the dictionary (groups first, else tasks) holding *target*.

        Raises ValueError when *target* is neither a task nor a group.
        """
        if target not in self.registered_targets():
            msg = f"{target} is not registered as a task or a group."
            raise ValueError(msg)
        return self.groups if target in self.groups else self.tasks

    def _check_not_registered(self, target):
        """Raise ValueError when *target* is already a task or a group."""
        if target in self.registered_targets():
            msg = f"Target {target} was already registered!"
            raise ValueError(msg)

    def report(self, mode="recursive"):
        """Return a textual report of every registered task and group.

        With mode "all", tasks and then groups are listed in registration
        order. With mode "recursive", groups come first, sorted by
        decreasing total(), followed by the tasks that belong to no group,
        sorted the same way. Any other mode yields only the header line.
        """
        s = "=== Callback Profiler Report ==="
        if mode == "all":
            if self.has_tasks():
                s += "\n ::Individual tasks::"
                for task in self.tasks.values():
                    s += "\n" + task.report(1)
            if self.has_groups():
                s += "\n ::Group tasks::"
                for group in self.groups.values():
                    s += "\n" + group.report(1)
        elif mode == "recursive":
            if self.has_groups():
                groups = sorted(
                    self.groups.values(), key=lambda x: x.total(), reverse=True
                )
                for group in groups:
                    s += "\n" + group.report(1)
            if self.has_tasks():
                individual_tasknames = set(self.tasks.keys()).difference(
                    self._tasks_in_group
                )
                tasknames = sorted(
                    individual_tasknames,
                    key=lambda x: self.tasks[x].total(),
                    reverse=True,
                )
                for taskname in tasknames:
                    s += "\n" + self.tasks[taskname].report(1)
        return s

    def __str__(self):
        return self.report()